package com.lxm.framework.mybatisplus.parser;

import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.parser.SqlInfo;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.lxm.framework.common.UsualConstant;
import com.lxm.framework.common.principle.StandardPrinciple;
import com.lxm.framework.mybatisplus.common.SqlConstant;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.Statements;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.commons.lang3.StringUtils;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Objects;

/**
 * MySQL statement rewriter built on JSqlParser.
 *
 * <p>Every parsed statement has its table and column identifiers wrapped in back-ticks and,
 * unless {@code intercept} is {@code true}, gets {@code tenant_id = <tenant>} and
 * {@code deleted = 0} conditions injected for each referenced table that actually has those
 * columns according to {@code information_schema.COLUMNS}.
 *
 * @Author: Lys
 * @Date 2022/2/25
 **/
@Slf4j
public class CustomizedSqlParser {

    /** Lookup result when the table has only a {@code tenant_id} column. */
    private static final int TENANT_FIELD = 1;
    /** Lookup result when the table has only a {@code deleted} column. */
    private static final int DELETED_FIELD = 2;
    /** Lookup result when the table has both columns. */
    private static final int BOTH_FIELDS = 3;

    /**
     * Column lookup for one table: parameter 1 is the schema, parameter 2 the bare table name.
     * Both are bound as JDBC parameters instead of being concatenated into the SQL, so a quote or
     * {@code %} inside an identifier can neither break nor alter the query.
     */
    private static final String BASE_FIELD_SQL =
            "select ifnull(sum(case column_name when 'tenant_id' then 1 when 'deleted' then 2 end), 0) "
                    + "from information_schema.`COLUMNS` where table_schema = ? and table_name = ? "
                    + "and column_name in ('tenant_id', 'deleted')";

    /** When {@code true}, tables are only escaped; no tenant/deleted conditions are injected. */
    private final boolean intercept;
    /** Connection used to resolve the schema and to look up table columns; may be {@code null}. */
    private final Connection connection;
    /** Database type; always {@link DbType#MYSQL}. */
    private final DbType dbType;
    /** SQL after parameter values were spliced in (stored but not read by this class). */
    private final String originalSql;
    /** Identifier escape prefix. */
    private final String escapeStart = "`";
    /** Identifier escape suffix. */
    private final String escapeEnd = "`";
    /** Schema that the column lookup is restricted to; resolved once in the constructor. */
    private String schema;
    /** Current principal; supplies the tenant id for the injected tenant condition. */
    private final StandardPrinciple principle;

    public CustomizedSqlParser(boolean intercept, Connection connection, String originalSql, StandardPrinciple principle) {
        this.intercept = intercept;
        this.connection = connection;
        this.dbType = DbType.MYSQL;
        this.originalSql = originalSql;
        this.principle = principle;
        this.beforeParser();
    }

    /**
     * Resolves the schema used by the column lookup.
     *
     * <p>It tries, in order: the connection catalog, then the path segment of the JDBC URL
     * (text after the last {@code /} and before any {@code ?}), then
     * {@code UsualConstant.DEFAULT_MYSQL_DB_NAME}. A {@code null} connection goes straight to the
     * default instead of failing with a NullPointerException.
     */
    private void beforeParser() {
        String resolved = null;
        if (connection != null) {
            try {
                resolved = connection.getCatalog();
            } catch (SQLException e) {
                log.debug("sql parseer get mysql schema faild.", e);
            }
            if (StringUtils.isBlank(resolved)) {
                try {
                    String jdbcUrl = connection.getMetaData().getURL();
                    if (StringUtils.isNotBlank(jdbcUrl)) {
                        resolved = StringUtils.substringBefore(StringUtils.substringAfterLast(jdbcUrl, "/"), "?");
                    }
                } catch (SQLException e) {
                    log.debug("sql parseer get mysql schema faild.", e);
                }
            }
        }
        if (StringUtils.isBlank(resolved)) {
            resolved = UsualConstant.DEFAULT_MYSQL_DB_NAME;
        }
        this.schema = resolved;
    }

    /**
     * Parses {@code osql} (one or more {@code ;}-separated statements), rewrites each statement
     * and joins the results back together with {@code ;}.
     *
     * @param osql SQL text to rewrite
     * @return the rewritten SQL, or {@code null} if no statement was found
     * @throws RuntimeException (built by {@code ExceptionUtils.mpe}) when the SQL cannot be parsed
     */
    public SqlInfo parser(String osql) {
        try {
            Statements statements = CCJSqlParserUtil.parseStatements(osql);
            StringBuilder sqlStringBuilder = new StringBuilder();
            int i = 0;
            for (Statement statement : statements.getStatements()) {
                if (null != statement) {
                    if (i++ > 0) {
                        sqlStringBuilder.append(';');
                    }
                    sqlStringBuilder.append(this.processParser(statement).getSql());
                }
            }
            if (sqlStringBuilder.length() > 0) {
                return SqlInfo.newInstance().setSql(sqlStringBuilder.toString());
            }
        } catch (JSQLParserException e) {
            throw ExceptionUtils.mpe("Failed to process, please exclude the tableName or statementId.\n Error SQL: %s", e, osql);
        }
        return null;
    }

    /**
     * Rewrites a single statement in place and returns its new SQL text.
     *
     * @param statement JSqlParser statement
     * @return SqlInfo holding {@code statement.toString()} after rewriting
     */
    private SqlInfo processParser(Statement statement) {
        if (statement instanceof Insert) {
            this.processInsert((Insert) statement);
        } else if (statement instanceof Select) {
            Select select = (Select) statement;
            // CTE bodies (WITH x AS (...)) live outside the main select body and need the same filters
            List<WithItem> withItems = select.getWithItemsList();
            if (withItems != null) {
                withItems.forEach(this::processSelectBody);
            }
            this.processSelectBody(select.getSelectBody());
        } else if (statement instanceof Update) {
            this.processUpdate((Update) statement);
        } else if (statement instanceof Delete) {
            this.processDelete((Delete) statement);
        }
        return SqlInfo.newInstance().setSql(statement.toString());
    }

    /**
     * Rewrites a physical DELETE.
     *
     * <p>It adds conditions for the main table and for any multi-table targets, then processes
     * joins and escapes identifiers in WHERE and ORDER BY.
     *
     * @param delete the DELETE statement
     */
    private void processDelete(Delete delete) {
        if (delete.getTable() != null) {
            delete.setWhere(processTable(delete.getTable(), delete.getWhere()));
        }
        if (delete.getTables() != null) {
            for (Table target : delete.getTables()) {
                delete.setWhere(processTable(target, delete.getWhere()));
            }
        }
        processJoins(delete.getJoins());
        escapeExpression(delete.getWhere());
        processOrderBys(delete.getOrderByElements());
    }

    /**
     * Rewrites an UPDATE.
     *
     * <p>It adds conditions for the target table, processes FROM/joins and any sub-select, and
     * escapes the SET columns and expressions, WHERE and ORDER BY.
     *
     * @param update the UPDATE statement
     */
    private void processUpdate(Update update) {
        if (update.getTable() != null) {
            update.setWhere(processTable(update.getTable(), update.getWhere()));
        }
        processFromItem(update.getFromItem());
        processJoins(update.getJoins());
        if (update.getSelect() != null) {
            processSelectBody(update.getSelect().getSelectBody());
        }
        if (update.getColumns() != null) {
            update.getColumns().forEach(this::escapeExpression);
        }
        if (update.getExpressions() != null) {
            update.getExpressions().forEach(this::escapeExpression);
        }
        escapeExpression(update.getWhere());
        processOrderBys(update.getOrderByElements());
    }

    /**
     * Rewrites an INSERT.
     *
     * <p>It escapes the target table and column list. For {@code INSERT ... SELECT}, the source
     * query is rewritten like any other SELECT, so it cannot read other tenants' or deleted rows.
     *
     * @param insert the INSERT statement
     */
    private void processInsert(Insert insert) {
        Table table = insert.getTable();
        table.setName(wrapTableOrColumn(table.getName()));
        if (insert.getColumns() != null) {
            insert.getColumns().forEach(this::escapeExpression);
        }
        if (insert.getSelect() != null) {
            // processSelectBody ignores body types it does not know, so a VALUES-style body is harmless here
            processSelectBody(insert.getSelect().getSelectBody());
        }
    }

    /**
     * Recursively escapes column identifiers inside an expression.
     *
     * <p>Sub-queries found along the way (IN, EXISTS, scalar sub-selects) are handed to
     * {@link #processFromItem} so they also receive tenant/deleted conditions. A {@code null}
     * expression is ignored.
     *
     * @param expression expression to walk
     */
    private void escapeExpression(Expression expression) {
        if (expression instanceof Column) {
            Column column = (Column) expression;
            column.setColumnName(wrapTableOrColumn(column.getColumnName()));
        } else if (expression instanceof BinaryExpression) {
            BinaryExpression be = (BinaryExpression) expression;
            escapeExpression(be.getLeftExpression());
            escapeExpression(be.getRightExpression());
        } else if (expression instanceof Parenthesis) {
            // e.g. "a = 1 AND (b = 2 OR c = 3)"; without this the grouped columns stayed unescaped
            escapeExpression(((Parenthesis) expression).getExpression());
        } else if (expression instanceof Between) {
            Between between = (Between) expression;
            escapeExpression(between.getLeftExpression());
            escapeExpression(between.getBetweenExpressionStart());
            escapeExpression(between.getBetweenExpressionEnd());
        } else if (expression instanceof InExpression) {
            InExpression ie = (InExpression) expression;
            escapeExpression(ie.getLeftExpression());
            if (ie.getRightItemsList() instanceof FromItem) {
                processFromItem((FromItem) ie.getRightItemsList());
            }
            // NOTE(review): depending on the JSqlParser version, an IN sub-select may be stored here
            // instead of in the items list; rewriting is idempotent, so visiting both is safe
            escapeExpression(ie.getRightExpression());
        } else if (expression instanceof CaseExpression) {
            CaseExpression ce = (CaseExpression) expression;
            escapeExpression(ce.getSwitchExpression());
            if (ce.getWhenClauses() != null) {
                ce.getWhenClauses().forEach(whenClause -> {
                    escapeExpression(whenClause.getWhenExpression());
                    escapeExpression(whenClause.getThenExpression());
                });
            }
            escapeExpression(ce.getElseExpression());
        } else if (expression instanceof IsNullExpression) {
            IsNullExpression ine = (IsNullExpression) expression;
            escapeExpression(ine.getLeftExpression());
        } else if (expression instanceof Function) {
            Function function = (Function) expression;
            if (function.getParameters() != null && function.getParameters().getExpressions() != null) {
                for (Expression parameter : function.getParameters().getExpressions()) {
                    // NOTE(review): unqualified column arguments are deliberately left unescaped;
                    // presumably to protect keyword-like arguments parsed as columns — confirm
                    if (parameter instanceof Column && ((Column) parameter).getTable() == null) {
                        continue;
                    }
                    escapeExpression(parameter);
                }
            }
        } else if (expression instanceof NotExpression) {
            NotExpression note = (NotExpression) expression;
            escapeExpression(note.getExpression());
        } else if (expression instanceof ExistsExpression) {
            ExistsExpression ee = (ExistsExpression) expression;
            if (ee.getRightExpression() instanceof FromItem) {
                processFromItem((FromItem) ee.getRightExpression());
            } else {
                escapeExpression(ee.getRightExpression());
            }
        } else if (expression instanceof SubSelect) {
            // scalar sub-query such as "a = (SELECT ...)": it must be filtered like any other query
            processFromItem((SubSelect) expression);
        }
    }

    /**
     * Processes non-table FROM items: sub-joins, sub-selects and lateral sub-selects.
     *
     * <p>Plain {@link Table} items and {@code null} are ignored here; tables are handled by
     * {@link #processTable}.
     *
     * @param fromItem the FROM item
     */
    private void processFromItem(FromItem fromItem) {
        if (fromItem instanceof SubJoin) {
            var subJoin = (SubJoin) fromItem;
            processJoins(subJoin.getJoinList());
            if (subJoin.getLeft() != null) {
                processFromItem(subJoin.getLeft());
            }
        } else if (fromItem instanceof SubSelect) {
            var subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelectBody(subSelect.getSelectBody());
            }
        } else if (fromItem instanceof ValuesList) {
            log.debug("Perform a subquery, if you do not give us feedback");
        } else if (fromItem instanceof LateralSubSelect) {
            var lateralSubSelect = (LateralSubSelect) fromItem;
            var subSelect = lateralSubSelect.getSubSelect();
            if (subSelect != null && subSelect.getSelectBody() != null) {
                processSelectBody(subSelect.getSelectBody());
            }
        }
    }

    /**
     * Processes JOIN clauses.
     *
     * <p>For each joined table, tenant/deleted conditions go into its ON expression. The ON
     * expression is then escaped, and sub-query join targets are processed.
     *
     * @param joins joins; {@code null} or empty is a no-op
     */
    private void processJoins(List<Join> joins) {
        if (joins == null || joins.isEmpty()) {
            return;
        }
        for (Join join : joins) {
            if (join.getRightItem() instanceof Table) {
                var joinedTable = (Table) join.getRightItem();
                join.setOnExpression(processTable(joinedTable, join.getOnExpression()));
            }
            escapeExpression(join.getOnExpression());
            processFromItem(join.getRightItem());
        }
    }

    /**
     * Dispatches a select body to the matching handler.
     *
     * <p>Handles plain selects, CTE items and set operations (UNION etc.). Any other body type, or
     * {@code null}, is left untouched; previously it was cast blindly to {@link SetOperationList}
     * and failed with a ClassCastException or NullPointerException.
     *
     * @param selectBody the select body
     */
    private void processSelectBody(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            processPlainSelect((PlainSelect) selectBody);
        } else if (selectBody instanceof WithItem) {
            var withItem = (WithItem) selectBody;
            if (withItem.getSubSelect() != null && withItem.getSubSelect().getSelectBody() != null) {
                processSelectBody(withItem.getSubSelect().getSelectBody());
            }
        } else if (selectBody instanceof SetOperationList) {
            var selects = ((SetOperationList) selectBody).getSelects();
            if (selects != null) {
                selects.forEach(this::processSelectBody);
            }
        }
    }

    /**
     * Rewrites a plain SELECT.
     *
     * <p>It escapes select items, adds conditions for the FROM table (or processes a sub-query
     * FROM), processes joins, and escapes WHERE, HAVING and ORDER BY.
     *
     * @param plainSelect the select
     */
    private void processPlainSelect(PlainSelect plainSelect) {
        if (plainSelect.getSelectItems() != null) {
            plainSelect.getSelectItems().forEach(selectItem -> {
                if (selectItem instanceof SelectExpressionItem) {
                    Expression itemExpression = ((SelectExpressionItem) selectItem).getExpression();
                    if (itemExpression instanceof FromItem) {
                        // scalar sub-query in the select list
                        processFromItem((FromItem) itemExpression);
                    } else {
                        // also covers parenthesised items, which escapeExpression unwraps
                        escapeExpression(itemExpression);
                    }
                }
            });
        }
        var fromItem = plainSelect.getFromItem();
        if (fromItem instanceof Table) {
            plainSelect.setWhere(processTable((Table) fromItem, plainSelect.getWhere()));
        } else {
            processFromItem(fromItem);
        }
        processJoins(plainSelect.getJoins());
        escapeExpression(plainSelect.getWhere());
        escapeExpression(plainSelect.getHaving());
        processOrderBys(plainSelect.getOrderByElements());
    }

    /**
     * Escapes the table name and, unless {@code intercept} is set, prepends tenant and/or
     * logical-delete conditions to {@code condition}.
     *
     * <p>A condition is added only when the table has the column and {@code condition} does not
     * already mention it (qualified by the alias when the table has one). The tenant condition
     * also requires a principal with a tenant id {@code >= 0}.
     *
     * <p>NOTE(review): the "already mentions" test is a substring match on the rendered
     * condition. Without an alias, any occurrence of the column name suppresses the injection,
     * even one belonging to another table or appearing inside a table name.
     *
     * @param table     table referenced by the statement; its name is escaped in place
     * @param condition existing WHERE/ON condition, may be {@code null}
     * @return the possibly extended condition
     */
    private Expression processTable(Table table, Expression condition) {
        table.setName(wrapTableOrColumn(table.getName()));
        if (intercept) {
            // no conditions will be injected, so skip the information_schema round-trip entirely
            return condition;
        }
        var alias = table.getAlias();
        String aliasName = alias == null ? null : alias.getName();
        int fields = getTableBaseFiled(table.getName());
        if (lacksCondition(condition, aliasName, SqlConstant.TID) && hasField(fields, TENANT_FIELD)
                && this.principle != null && this.principle.getTenantId() >= 0) {
            condition = builderTenantExpression(condition, table);
        }
        if (lacksCondition(condition, aliasName, SqlConstant.DELETED) && hasField(fields, DELETED_FIELD)) {
            condition = builderDeletedExpression(condition, table);
        }
        return condition;
    }

    /**
     * Tells whether the column-lookup result reports {@code field}.
     *
     * @param fields lookup result (0 none, 1 tenant, 2 deleted, 3 both)
     * @param field  {@link #TENANT_FIELD} or {@link #DELETED_FIELD}
     * @return true when the table has that column
     */
    private boolean hasField(int fields, int field) {
        return fields == field || fields == BOTH_FIELDS;
    }

    /**
     * Tells whether {@code condition} does NOT already reference {@code column}.
     *
     * @param condition existing condition, may be {@code null}
     * @param aliasName table alias, or {@code null} when the table has none
     * @param column    column name to look for
     * @return true when a condition on {@code column} should still be added
     */
    private boolean lacksCondition(Expression condition, String aliasName, CharSequence column) {
        if (Objects.isNull(condition)) {
            return true;
        }
        String text = condition.toString();
        if (aliasName == null) {
            return !StringUtils.containsIgnoreCase(text, column);
        }
        String qualifier = aliasName + ".";
        return !StringUtils.containsIgnoreCase(text, qualifier + column)
                && !StringUtils.containsIgnoreCase(text, qualifier + this.escapeStart + column + this.escapeEnd);
    }

    /**
     * Builds {@code <table>.deleted = 0} and ANDs it in front of {@code expression}.
     *
     * @param expression existing condition, may be {@code null}
     * @param table      table the column belongs to
     * @return the new condition
     */
    private Expression builderDeletedExpression(Expression expression, Table table) {
        var equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(this.getDeletedAliasColumn(table));
        equalsTo.setRightExpression(new LongValue(0L));
        // avoid producing "... AND null" when there was no condition yet
        if (expression == null) {
            return equalsTo;
        }
        return new AndExpression(equalsTo, expression);
    }

    /**
     * Builds the logical-delete column, qualified as {@code tableName.deleted} or
     * {@code tableAlias.deleted}.
     *
     * @param table table
     * @return Column
     */
    private Column getDeletedAliasColumn(Table table) {
        return new Column(table, "deleted");
    }

    /**
     * Builds {@code <table>.tenant_id = <current tenant id>} and ANDs it in front of
     * {@code expression}.
     *
     * @param expression existing condition, may be {@code null}
     * @param table      table the column belongs to
     * @return the new condition
     */
    private Expression builderTenantExpression(Expression expression, Table table) {
        var equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(this.getTenantAliasColumn(table));
        equalsTo.setRightExpression(new LongValue(this.principle.getTenantId()));
        // avoid producing "... AND null" when there was no condition yet
        if (expression == null) {
            return equalsTo;
        }
        return new AndExpression(equalsTo, expression);
    }

    /**
     * Builds the tenant column, qualified as {@code tableName.tenant_id} or
     * {@code tableAlias.tenant_id}.
     *
     * @param table table
     * @return Column
     */
    private Column getTenantAliasColumn(Table table) {
        return new Column(table, "tenant_id");
    }

    /**
     * Looks up whether the table has {@code tenant_id} and/or {@code deleted} columns.
     *
     * <p>The lookup is best-effort: with no connection, or on any error (which is logged), it
     * reports 0.
     *
     * @param tableName table name, possibly already escaped
     * @return 0 none, 1 has tenant_id, 2 has deleted, 3 has both
     */
    private int getTableBaseFiled(String tableName) {
        if (this.connection == null) {
            return 0;
        }
        String bareName = StringUtils.replaceEach(StringUtils.trim(tableName), new String[]{escapeStart, escapeEnd}, new String[]{"", ""});
        try (PreparedStatement statement = connection.prepareStatement(BASE_FIELD_SQL)) {
            statement.setString(1, this.schema);
            statement.setString(2, bareName);
            try (ResultSet resultSet = statement.executeQuery()) {
                return resultSet.next() ? resultSet.getInt(1) : 0;
            }
        } catch (Exception e) {
            log.error("sql parser get table basefiled error", e);
            return 0;
        }
    }

    /**
     * Escapes identifiers in ORDER BY elements.
     *
     * @param orderByElements ORDER BY list; {@code null} or empty is a no-op
     */
    private void processOrderBys(List<OrderByElement> orderByElements) {
        if (orderByElements == null || orderByElements.isEmpty()) {
            return;
        }
        orderByElements.forEach(orderByElement -> escapeExpression(orderByElement.getExpression()));
    }

    /**
     * Wraps an identifier (table or column name) in the escape characters.
     *
     * <p>Surrounding whitespace and existing back-tick / bracket quoting are removed first, so
     * calling it twice gives the same result.
     *
     * @param str identifier, e.g. a table or column name; {@code null} is returned unchanged
     * @return the escaped identifier
     */
    private String wrapTableOrColumn(String str) {
        if (str == null) {
            return null;
        }
        String bare = StringUtils.replaceEach(StringUtils.trim(str), new String[]{"`", "[", "]"}, new String[]{"", "", ""});
        return this.escapeStart + bare + this.escapeEnd;
    }
}
